{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Lorentz ODE Solver",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ntE40GybB_cn",
        "colab_type": "text"
      },
      "source": [
        "# Lorenz ODE Solver in JAX\n",
        "Alex Alemi"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fyoHa_blbI71",
        "colab_type": "text"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GAFiL4V_kPE8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import io\n",
        "import os\n",
        "from functools import partial\n",
        "import numpy as np\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax import vmap, jit, grad, ops, lax, config\n",
        "from jax import random as jr\n",
        "\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.cm as cm\n",
        "from IPython.display import display_png\n",
        "\n",
        "mpl.rcParams['savefig.pad_inches'] = 0\n",
        "# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and\n",
        "# later releases dropped the old names, so pick whichever one is available.\n",
        "_dark_style = 'seaborn-v0_8-dark' if 'seaborn-v0_8-dark' in plt.style.available else 'seaborn-dark'\n",
        "plt.style.use(_dark_style)\n",
        "%matplotlib inline"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vruUCSlrU_L7",
        "colab_type": "text"
      },
      "source": [
        "# Plotting Utilities\n",
        "\n",
        "These utilities provide fast line plotting with better antialiasing than the typical matplotlib plotting routines."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aTVqxdEQLZwM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@jit\n",
        "def drawline(im, x0, y0, x1, y1):\n",
        "  \"\"\"An implementation of Wu's antialiased line algorithm.\n",
        "  \n",
        "  This functional version was adapted from here:\n",
        "    https://en.wikipedia.org/wiki/Xiaolin_Wu's_line_algorithm\n",
        "  \"\"\"\n",
        "\n",
        "  ipart = lambda x: jnp.floor(x).astype('int32')\n",
        "  round_ = lambda x: ipart(x + 0.5).astype('int32')\n",
        "  fpart = lambda x: x - jnp.floor(x)\n",
        "  rfpart = lambda x: 1 - fpart(x)\n",
        "\n",
        "  def plot(im, x, y, c):\n",
        "    return ops.index_add(im, ops.index[x, y], c)\n",
        "\n",
        "  steep = jnp.abs(y1 - y0) > jnp.abs(x1 - x0)\n",
        "  cond_swap = lambda cond, x: lax.cond(cond, x, lambda x: (x[1], x[0]), x, lambda x: x)\n",
        "  \n",
        "  (x0, y0) = cond_swap(steep, (x0, y0))\n",
        "  (x1, y1) = cond_swap(steep, (x1, y1))\n",
        "  \n",
        "  (y0, y1) = cond_swap(x0 > x1, (y0, y1))\n",
        "  (x0, x1) = cond_swap(x0 > x1, (x0, x1))\n",
        "\n",
        "  dx = x1 - x0\n",
        "  dy = y1 - y0\n",
        "  gradient = jnp.where(dx == 0.0, 1.0, dy/dx)\n",
        "\n",
        "  # handle first endpoint\n",
        "  xend = round_(x0)\n",
        "  yend = y0 + gradient * (xend - x0)\n",
        "  xgap = rfpart(x0 + 0.5)\n",
        "  xpxl1 = xend # this will be used in main loop\n",
        "  ypxl1 = ipart(yend)\n",
        "\n",
        "  def true_fun(im):\n",
        "    im = plot(im, ypxl1, xpxl1, rfpart(yend) * xgap)\n",
        "    im = plot(im, ypxl1+1, xpxl1,  fpart(yend) * xgap)\n",
        "    return im\n",
        "  def false_fun(im):\n",
        "    im = plot(im, xpxl1, ypxl1  , rfpart(yend) * xgap)\n",
        "    im = plot(im, xpxl1, ypxl1+1,  fpart(yend) * xgap)\n",
        "    return im\n",
        "  im = lax.cond(steep, im, true_fun, im, false_fun)\n",
        "  \n",
        "  intery = yend + gradient\n",
        "\n",
        "  # handle second endpoint\n",
        "  xend = round_(x1)\n",
        "  yend = y1 + gradient * (xend - x1)\n",
        "  xgap = fpart(x1 + 0.5)\n",
        "  xpxl2 = xend  # this will be used in the main loop\n",
        "  ypxl2 = ipart(yend)\n",
        "  def true_fun(im):\n",
        "    im = plot(im, ypxl2  , xpxl2, rfpart(yend) * xgap)\n",
        "    im = plot(im, ypxl2+1, xpxl2,  fpart(yend) * xgap)\n",
        "    return im\n",
        "  def false_fun(im):\n",
        "    im = plot(im, xpxl2, ypxl2,  rfpart(yend) * xgap)\n",
        "    im = plot(im, xpxl2, ypxl2+1, fpart(yend) * xgap)\n",
        "    return im\n",
        "  im = lax.cond(steep, im, true_fun, im, false_fun)\n",
        "  \n",
        "  def true_fun(arg):\n",
        "    im, intery = arg\n",
        "    def body_fun(x, arg):\n",
        "      im, intery = arg\n",
        "      im = plot(im, ipart(intery), x, rfpart(intery))\n",
        "      im = plot(im, ipart(intery)+1, x, fpart(intery))\n",
        "      intery = intery + gradient\n",
        "      return (im, intery)\n",
        "    im, intery = lax.fori_loop(xpxl1+1, xpxl2, body_fun, (im, intery))\n",
        "    return (im, intery)\n",
        "  def false_fun(arg):\n",
        "    im, intery = arg\n",
        "    def body_fun(x, arg):\n",
        "      im, intery = arg\n",
        "      im = plot(im, x, ipart(intery), rfpart(intery))\n",
        "      im = plot(im, x, ipart(intery)+1, fpart(intery))\n",
        "      intery = intery + gradient\n",
        "      return (im, intery)\n",
        "    im, intery = lax.fori_loop(xpxl1+1, xpxl2, body_fun, (im, intery))\n",
        "    return (im, intery)\n",
        "  im, intery = lax.cond(steep, (im, intery), true_fun, (im, intery), false_fun)\n",
        "  \n",
        "  return im\n",
        "\n",
        "def img_adjust(data):\n",
        "  oim = np.array(data)\n",
        "  hist, bin_edges = np.histogram(oim.flat, bins=256*256)\n",
        "  bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2\n",
        "  cdf = hist.cumsum()\n",
        "  cdf = cdf / float(cdf[-1])\n",
        "  return np.interp(oim.flat, bin_centers, cdf).reshape(oim.shape)\n",
        "\n",
        "def imify(arr, vmin=None, vmax=None, cmap=None, origin=None):\n",
        "  \"\"\"Convert a 2-D array into an RGBA byte image.\n",
        "\n",
        "  The array is first passed through `img_adjust`, then colour-mapped with a\n",
        "  matplotlib `ScalarMappable` using `cmap` and the colour limits\n",
        "  (`vmin`, `vmax`). When `origin` (or, if it is None, the `image.origin`\n",
        "  rcParam) is \"lower\", the rows are reversed before colour mapping.\n",
        "  \"\"\"\n",
        "  equalized = img_adjust(arr)\n",
        "  mapper = cm.ScalarMappable(cmap=cmap)\n",
        "  mapper.set_clim(vmin, vmax)\n",
        "  resolved_origin = mpl.rcParams[\"image.origin\"] if origin is None else origin\n",
        "  if resolved_origin == \"lower\":\n",
        "    equalized = equalized[::-1]\n",
        "  return mapper.to_rgba(equalized, bytes=True)\n",
        "\n",
        "def plot_image(array, **kwargs):\n",
        "  \"\"\"Render `array` as a PNG and display it inline.\n",
        "\n",
        "  Args:\n",
        "    array: 2-D array to render; it is colour-mapped by `imify`.\n",
        "    **kwargs: forwarded unchanged to `imify` (vmin, vmax, cmap, origin).\n",
        "  \"\"\"\n",
        "  imarray = imify(array, **kwargs)\n",
        "  # The context manager closes the buffer even if `imsave` raises.\n",
        "  with io.BytesIO() as f:\n",
        "    plt.imsave(f, imarray, format=\"png\")\n",
        "    dat = f.getvalue()\n",
        "  display_png(dat, raw=True)\n",
        "\n",
        "def pack_images(images, rows, cols):\n",
        "  shape = np.shape(images)\n",
        "  width, height, depth = shape[-3:]\n",
        "  images = np.reshape(images, (-1, width, height, depth))\n",
        "  batch = np.shape(images)[0]\n",
        "  rows = np.minimum(rows, batch)\n",
        "  cols = np.minimum(batch // rows, cols)\n",
        "  images = images[:rows * cols]\n",
        "  images = np.reshape(images, (rows, cols, width, height, depth))\n",
        "  images = np.transpose(images, [0, 2, 1, 3, 4])\n",
        "  images = np.reshape(images, [rows * width, cols * height, depth])\n",
        "  return images"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FFkdRUDR9cWD",
        "colab_type": "text"
      },
      "source": [
        "# Lorenz Dynamics\n",
        "\n",
        "Implement the Lorenz system, whose trajectories trace out the Lorenz attractor."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aoSvqedskd0W",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sigma = 10.\n",
        "beta = 8./3\n",
        "rho = 28.\n",
        "\n",
        "@jit\n",
        "def f(state, t):\n",
        "  x, y, z = state\n",
        "  return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tanYn8Cx9hUb",
        "colab_type": "text"
      },
      "source": [
        "# Runge Kutta Integrator"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ejuN_R7Km28v",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@jit\n",
        "def rk4(ys, dt, N):\n",
        "  \"\"\"Integrate `f` with the classic fourth-order Runge-Kutta scheme.\n",
        "\n",
        "  Args:\n",
        "    ys: array of states; row 0 holds the initial condition and rows\n",
        "      1 .. N-1 are overwritten with the computed trajectory.\n",
        "    dt: step size.\n",
        "    N: number of rows to fill (the loop runs for i in [1, N)).\n",
        "\n",
        "  Returns:\n",
        "    `ys` with row i set to the RK4 update of row i-1.\n",
        "  \"\"\"\n",
        "  def step(i, ys):\n",
        "    h = dt\n",
        "    # Row i-1 is the state at time (i-1)*dt (taking row 0 as t = 0), so the\n",
        "    # stage times are measured from there, not from i*dt.\n",
        "    t = dt * (i - 1)\n",
        "    y_prev = ys[i-1]\n",
        "    k1 = h * f(y_prev, t)\n",
        "    k2 = h * f(y_prev + k1/2., t + h/2.)\n",
        "    k3 = h * f(y_prev + k2/2., t + h/2.)\n",
        "    k4 = h * f(y_prev + k3, t + h)\n",
        "\n",
        "    ysi = y_prev + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4)\n",
        "    # `jax.ops.index_update` was removed from JAX; use the `.at[]` API.\n",
        "    return ys.at[i].set(ysi)\n",
        "  return lax.fori_loop(1, N, step, ys)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i2UIxo3Z9PZ2",
        "colab_type": "text"
      },
      "source": [
        "# Solve and plot a single ODE Solution using jitted solver and plotter"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XvROzDrukzH_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "N = 40000\n",
        "\n",
        "# set initial condition: row 0 of the trajectory buffer holds the start state\n",
        "# (`jax.ops.index_update` was removed from JAX; `.at[].set` replaces it)\n",
        "state0 = jnp.array([1., 1., 1.])\n",
        "ys = jnp.zeros((N,) + state0.shape)\n",
        "ys = ys.at[0].set(state0)\n",
        "\n",
        "# solve for N steps\n",
        "ys = rk4(ys, 0.004, N).block_until_ready()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "t4k3UrtbM4jy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# plotting size and region:\n",
        "xlim, zlim = (-20, 20), (0, 50)\n",
        "xN, zN = 800, 600\n",
        "\n",
        "# fast, jitted plotting function\n",
        "@partial(jax.jit, static_argnums=(2,3,4,5))\n",
        "def jplotter(xs, zs, xlim, zlim, xN, zN):\n",
        "  \"\"\"Rasterize the polyline (xs[i], zs[i]) into an (xN, zN) image.\n",
        "\n",
        "  `xlim` and `zlim` give the (low, high) data ranges that are mapped onto\n",
        "  the pixel ranges [0, xN) and [0, zN). Consecutive points are joined with\n",
        "  `drawline`, accumulating into an initially zero image.\n",
        "  \"\"\"\n",
        "  x_span = 1.0 * (xlim[1] - xlim[0])\n",
        "  z_span = 1.0 * (zlim[1] - zlim[0])\n",
        "  xpixels = (xs - xlim[0])/x_span * xN\n",
        "  zpixels = (zs - zlim[0])/z_span * zN\n",
        "  def draw_segment(i, canvas):\n",
        "    # join point i-1 to point i\n",
        "    return drawline(canvas, xpixels[i-1], zpixels[i-1], xpixels[i], zpixels[i])\n",
        "  blank = jnp.zeros((xN, zN))\n",
        "  return lax.fori_loop(1, xpixels.shape[0], draw_segment, blank)\n",
        "\n",
        "# rasterize the x-z projection (components 0 and 2) of the trajectory\n",
        "im = jplotter(ys[...,0], ys[...,2], xlim, zlim, xN, zN)\n",
        "# reverse the z axis and transpose so that z runs along the image rows\n",
        "plot_image(im[:,::-1].T, cmap='magma')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JWkKc-mh7m9x",
        "colab_type": "text"
      },
      "source": [
        "# Parallel ODE Solutions with Pmap"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tlc8Y_pfOERv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "N_dev = jax.device_count()\n",
        "N = 4000\n",
        "\n",
        "# set some initial conditions for each replicate: one random start state per\n",
        "# device, written into row 0 of that replicate's trajectory buffer\n",
        "ys = jnp.zeros((N_dev, N, 3))\n",
        "state0 = jr.uniform(jr.PRNGKey(1),\n",
        "                    minval=-1., maxval=1.,\n",
        "                    shape=(N_dev, 3))\n",
        "state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n",
        "# `jax.ops.index_update` was removed from JAX; `.at[].set` replaces it\n",
        "ys = ys.at[:, 0].set(state0)\n",
        "\n",
        "# solve each replicate in parallel using `pmap` of rk4 solver; every argument\n",
        "# carries a leading device axis, so dt and N are replicated per device:\n",
        "ys = jax.pmap(rk4)(ys,\n",
        "                   0.004 * jnp.ones(N_dev),\n",
        "                   N * jnp.ones(N_dev, dtype=np.int32)\n",
        "                  ).block_until_ready()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_NdalA1qy1Fp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# parallel plotter using lexical closure and pmap'd core plotting function\n",
        "def pplotter(_xs, _zs, xlim, zlim, xN, zN):\n",
        "  \"\"\"Rasterize one polyline per leading-axis entry, in parallel via `pmap`.\n",
        "\n",
        "  `_xs` and `_zs` carry a leading axis with one trajectory per entry. The\n",
        "  limits and image size are captured by closure, and the result has shape\n",
        "  (_xs.shape[0], xN, zN).\n",
        "  \"\"\"\n",
        "  n_replicas = _xs.shape[0]\n",
        "  x_span = 1.0 * (xlim[1] - xlim[0])\n",
        "  z_span = 1.0 * (zlim[1] - zlim[0])\n",
        "\n",
        "  def rasterize(canvas, xs, zs):\n",
        "    xpixels = (xs - xlim[0])/x_span * xN\n",
        "    zpixels = (zs - zlim[0])/z_span * zN\n",
        "    def draw_segment(i, canvas):\n",
        "      # join point i-1 to point i\n",
        "      return drawline(canvas, xpixels[i-1], zpixels[i-1], xpixels[i], zpixels[i])\n",
        "    return lax.fori_loop(1, xpixels.shape[0], draw_segment, canvas)\n",
        "\n",
        "  canvases = jnp.zeros((n_replicas, xN, zN))\n",
        "  return jax.pmap(rasterize)(canvases, _xs, _zs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vhZyGqHUYkKK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# plotting region (data coordinates) and per-panel image size (pixels)\n",
        "xlim, zlim = (-20, 20), (0, 50)\n",
        "xN, zN = 200, 150\n",
        "# first figure: plot ODE traces separately, one panel per replicate, tiled\n",
        "# by `pack_images` into a grid of at most 4 x 2 panels\n",
        "ims = pplotter(ys[...,0], ys[...,2], xlim, zlim, xN, zN)\n",
        "im = pack_images(ims[..., None], 4, 2)[..., 0]\n",
        "plot_image(im[:,::-1].T, cmap='magma')\n",
        "# second figure: plot combined ODE traces, rendered at 4x the resolution and\n",
        "# summed over the replicate axis into a single image\n",
        "ims = pplotter(ys[...,0], ys[...,2], xlim, zlim, xN*4, zN*4)\n",
        "plot_image(jnp.sum(ims, axis=0)[:,::-1].T, cmap='magma')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S6c5GWHBbkEX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
